#!/usr/bin/env python
import numpy as np
import pandas as pd
import sys
import os

# R_G bin label -> midpoint of that bin in kpc. Each label is the two bin
# edges (one decimal place) joined by an en dash (U+2013); the midpoint is
# computed from the same edges so label and value cannot drift apart.
RG_MID = {
    f"{lo:.1f}–{hi:.1f}": 0.5 * (lo + hi)
    for lo, hi in ((1.5, 3.0), (3.0, 5.0), (5.0, 8.0), (8.0, 12.0))
}

# Stellar-mass bins in log10(M*/M_sun), lowest to highest; results are
# reported in this order.
ORDERED_MASS = [
    f"{lo:.1f}–{hi:.1f}"
    for lo, hi in ((10.2, 10.5), (10.5, 10.8), (10.8, 11.1))
]


# Characters that may appear instead of the en dash (U+2013) separating the two
# edges of a bin label: ASCII hyphen-minus, Unicode hyphens, figure dash,
# em dash and the minus sign.
_DASH_TRANSLATION = str.maketrans(
    {
        "-": "\u2013",
        "\u2010": "\u2013",
        "\u2011": "\u2013",
        "\u2012": "\u2013",
        "\u2014": "\u2013",
        "\u2212": "\u2013",
    }
)


def _normalize_bin_label(label):
    """Return *label* with all whitespace removed and dash variants mapped to an en dash.

    Non-string values (e.g. NaN from an empty CSV cell) are returned unchanged.
    """
    if not isinstance(label, str):
        return label
    return "".join(label.split()).translate(_DASH_TRANSLATION)


def load_plateau(
    path: str,
    *,
    rg_mid: "dict[str, float] | None" = None,
) -> pd.DataFrame:
    """Load plateau results and attach ``RG_mid`` and ``sigma`` columns.

    Parameters
    ----------
    path:
        CSV file to read (anything ``pd.read_csv`` accepts). Must contain an
        ``R_G_bin`` column; ``claimable``, ``Mstar_bin``, ``A_theta_CI_low`` and
        ``A_theta_CI_high`` are used when present.
    rg_mid:
        Mapping from R_G bin label to bin midpoint in kpc. Defaults to the
        module-level ``RG_MID`` table.

    Returns
    -------
    pd.DataFrame
        A copy of the table, restricted to rows whose ``claimable`` flag is
        True when that column exists, with:

        * ``R_G_bin`` (and ``Mstar_bin`` if present) labels normalised so that
          hyphens / other dash variants and stray spaces match the en-dash
          labels used by the lookup tables;
        * ``RG_mid``: the bin midpoint in kpc (NaN for unrecognised labels,
          which are reported on stderr);
        * ``sigma``: an approximate 1σ per stack, half the width of the
          16–84% credible interval. Non-finite values are replaced by the
          median positive sigma, non-positive ones by the smallest positive
          sigma, and if no positive sigma exists every row gets 1e-3.
    """
    midpoints = RG_MID if rg_mid is None else rg_mid
    # Normalise the keys the same way as the data so either spelling matches.
    midpoints = {_normalize_bin_label(k): v for k, v in midpoints.items()}

    df = pd.read_csv(path)

    # Only keep claimable stacks
    if "claimable" in df.columns:
        df = df[df["claimable"].eq(True)].copy()
    else:
        df = df.copy()

    # An exact string lookup would silently yield NaN for "3.0-5.0" vs
    # "3.0–5.0", dropping the stack from every downstream comparison.
    df["R_G_bin"] = df["R_G_bin"].map(_normalize_bin_label)
    if "Mstar_bin" in df.columns:
        df["Mstar_bin"] = df["Mstar_bin"].map(_normalize_bin_label)

    df["RG_mid"] = df["R_G_bin"].map(midpoints)

    unknown = df.loc[df["RG_mid"].isna() & df["R_G_bin"].notna(), "R_G_bin"]
    if not unknown.empty:
        labels = ", ".join(sorted({str(v) for v in unknown}))
        print(
            f"warning: unrecognised R_G_bin labels (RG_mid left NaN): {labels}",
            file=sys.stderr,
        )

    # Per-stack sigma: half the 16–84% CI width (≈1σ for a Gaussian posterior).
    if {"A_theta_CI_low", "A_theta_CI_high"}.issubset(df.columns):
        sig = (df["A_theta_CI_high"] - df["A_theta_CI_low"]) / 2.0
    else:
        # Fallback: use a small constant if CIs are missing
        sig = pd.Series(1e-3, index=df.index)

    sig = sig.replace([np.inf, -np.inf], np.nan)
    positive = sig > 0
    if positive.any():
        fill_value = sig[positive].median()
        floor_value = sig[positive].min()
        sig = sig.fillna(fill_value)
        # Non-positive widths would be unusable as a normal scale; clamp them.
        sig = sig.where(sig > 0, floor_value)
    else:
        sig = pd.Series(1e-3, index=df.index)

    df["sigma"] = sig
    return df


def roll_probs(
    df: pd.DataFrame,
    draws: int = 200000,
    seed: int = 42,
    *,
    mass_bins: "list[str] | None" = None,
):
    """For each mass bin, estimate P(outer > mid) and the mean Δ_out-mid by Monte Carlo.

    Δ_out-mid = A(10 kpc) - ½·[A(4 kpc) + A(6.5 kpc)], where each amplitude is
    sampled from Normal(A_theta, sigma) of the matching stack.

    Parameters
    ----------
    df:
        Table with ``Mstar_bin``, ``RG_mid``, ``A_theta`` and ``sigma`` columns
        (e.g. the output of ``load_plateau``).
    draws:
        Number of Monte Carlo samples per mass bin.
    seed:
        Seed of the single generator shared by all bins. Bins lacking coverage
        consume no draws, so a bin's samples depend on which earlier bins were
        sampled.
    mass_bins:
        Mass-bin labels to evaluate, in output order. Defaults to the
        module-level ``ORDERED_MASS`` list.

    Returns
    -------
    list of dict
        One dict per mass bin with keys ``Mstar_bin``, ``P_out_gt_mid`` and
        ``Delta_out_mid_mu``. Both statistics are NaN when the bin lacks a
        usable stack (finite ``A_theta`` and ``sigma``) at 4.0, 6.5 or 10.0 kpc.
        If several usable stacks share a radius, the last one wins.
    """
    bins = ORDERED_MASS if mass_bins is None else mass_bins
    required_radii = (4.0, 6.5, 10.0)  # mids of the 3–5, 5–8 and 8–12 kpc bins
    rng = np.random.default_rng(seed)
    results = []

    for mbin in bins:
        g = df[df["Mstar_bin"] == mbin]

        # A stack with a NaN amplitude or sigma would produce NaN draws, making
        # (dom > 0) silently False and reporting P ≈ 0 instead of "missing".
        needed = {}
        for rg_mid, amp, sig in zip(g["RG_mid"], g["A_theta"], g["sigma"]):
            if rg_mid in required_radii and np.isfinite(amp) and np.isfinite(sig):
                needed[rg_mid] = (amp, sig)

        # Need outer (10.0) and both mids (4.0, 6.5)
        if not all(k in needed for k in required_radii):
            results.append(
                {
                    "Mstar_bin": mbin,
                    "P_out_gt_mid": np.nan,
                    "Delta_out_mid_mu": np.nan,
                }
            )
            continue

        # Same draw order as before (4.0, 6.5, 10.0) to keep seeded results stable.
        a4 = rng.normal(*needed[4.0], draws)
        a65 = rng.normal(*needed[6.5], draws)
        a10 = rng.normal(*needed[10.0], draws)

        # Outer amplitude minus the mean of the two mid-radius amplitudes.
        dom = a10 - 0.5 * (a4 + a65)

        results.append(
            {
                "Mstar_bin": mbin,
                "P_out_gt_mid": float((dom > 0).mean()),
                "Delta_out_mid_mu": float(dom.mean()),
            }
        )

    return results


def main():
    """Command-line entry point.

    Usage: ``script [PLATEAU_CSV]``; the CSV path defaults to
    ``outputs/lensing_plateau.csv``. The table is loaded with
    ``load_plateau`` and summarised with ``roll_probs``. A header naming the
    file (without directory or extension) is printed first, then one line per
    mass bin with either the probability and mean difference or a note that
    coverage is insufficient.
    """
    cli_args = sys.argv[1:]
    csv_path = cli_args[0] if cli_args else "outputs/lensing_plateau.csv"

    summary = roll_probs(load_plateau(csv_path))

    stem = os.path.splitext(os.path.basename(csv_path))[0]
    print(f"=== {stem} ===")

    for entry in summary:
        label = entry["Mstar_bin"]
        prob = entry["P_out_gt_mid"]
        delta = entry["Delta_out_mid_mu"]
        if not np.isnan(prob):
            print(f"{label}: P(out>mid) = {prob:.3f}   Δ_out-mid ≈ {delta:.4f}")
        else:
            print(f"{label}: insufficient coverage (outer or mids missing)")


# Run the command-line summary only when executed as a script, not on import.
if __name__ == "__main__":
    main()

